# -*- coding: utf-8 -*-
"""Untitled27.ipynb

Automatically generated by Colab.

Original file is located at
    https://colab.research.google.com/drive/1B9SZoZrU-cBPcNbOTgAVuCVQKlpso2-R
"""

# -*- coding: utf-8 -*-
"""
Library for Matrix Product Approximation Algorithms and Bounds Comparison.

Contains functions for:
- Matrix generation (Uniform, Gaussian, Row Orthogonal, Repeated Cols, Nonlinear)
- Sampling algorithms (Uniform, RMM/Leverage Score, Deterministic)
- Custom algorithms (Greedy OMP, Gaussian Projection)
- Theoretical bounds calculation (User-provided and Standard)
- Experiment running logic
- Multi-panel plotting (NO LEGENDS)
"""

import numpy as np
import scipy
import scipy.linalg
import scipy.sparse
from scipy.sparse.linalg import svds
import time
import os
import matplotlib.pyplot as plt
import matplotlib.ticker as mticker
import pandas as pd
import warnings
import traceback
from typing import List, Tuple, Dict, Optional, Union, Any

# --- Global Warning Filters ---
# Both filters are installed process-wide as soon as this module is imported.
# Ignore UserWarning messages attributed to modules matching "matplotlib".
warnings.filterwarnings("ignore", category=UserWarning, module="matplotlib")
# Ignore every RuntimeWarning regardless of where it is raised.
# NOTE(review): this also hides the RuntimeWarning messages that
# compute_standard_bounds emits through warnings.warn(..., RuntimeWarning)
# when one of its bound computations fails.
warnings.filterwarnings("ignore", category=RuntimeWarning) # Broadly ignore runtime warnings

# --- Global Font and Plot Styling (User Provided) ---
# Applied once, at import time, to matplotlib's global rcParams: figures
# created after this module is imported pick up these defaults (large fonts,
# thick lines, light grid, white backgrounds, high-resolution output).
plt.rcParams.update({
    'font.size': 20,
    'axes.titlesize': 22,
    'axes.labelsize': 20,
    'xtick.labelsize': 18,
    'ytick.labelsize': 18,
    'legend.fontsize': 16,
    'figure.titlesize': 26,
    'figure.figsize': (18, 8), # Default for single plot, multi-panel adjusts
    'figure.dpi': 150,
    'savefig.dpi': 300,
    'lines.linewidth': 2.5,
    'lines.markersize': 10,
    'axes.linewidth': 1.5,
    'grid.linewidth': 1.0,
    'axes.grid': True,
    'grid.alpha': 0.3,
    'axes.titleweight': 'bold',
    'axes.labelweight': 'bold',
    'figure.titleweight': 'bold',
    'mathtext.default': 'regular',
    'mathtext.fontset': 'cm',
    'figure.facecolor': 'white',
    'axes.facecolor': 'white',
    'savefig.facecolor': 'white'
})
# --- End of Global Styling ---

# Optional dependency for the QP bound. `cvxpy_present` records whether the
# import succeeded; compute_theoretical_bounds checks it before touching `cp`.
try:
    import cvxpy as cp
    cvxpy_present = True
except ImportError:
    cvxpy_present = False
    # ImportWarning is ignored by Python's default warning filters, so the
    # notice is issued as a UserWarning to make sure it is actually displayed.
    warnings.warn("CVXPY not found. QP-based bounds (Your Bound (QP CVXPY Best)) will not be computed.", UserWarning)

# ==============================================================================
# Configuration & Styling (Line Styles - User Provided IMPROVED_STYLES)
# ==============================================================================
# Maps a result-series name to the keyword arguments used when drawing it
# (colour, marker, line style, label, widths, z-order).
# The keys double as the series names stored in the result dictionaries built
# by run_algorithm_experiments and compute_all_bounds_orchestrator, so a key
# must match the corresponding result key exactly to be picked up.
# NOTE(review): 'Optimal Error v_k^*', 'Error CountSketch (Actual)',
# 'Error SRHT (Actual)' and 'Our Bound (v_k Ratio)' are not produced by the
# experiment/bound functions defined above the plotting code in this file —
# confirm whether they are filled in elsewhere.
IMPROVED_STYLES = {
    'Optimal Error v_k^*': {
        'color': 'gold', 'marker': '*', 'linestyle': '-', 'label': r'Optimal $v_k^*$',
        'lw': 4.0, 'markersize': 16, 'zorder': 10, 'markeredgewidth': 1.5, 'markeredgecolor': 'black'
    },
    'Your Bound (QP CVXPY Best)': {
        'color': 'black', 'marker': 'o', 'linestyle': '-', 'label': 'Bound (QP Best)',
        'lw': 3.5, 'markersize': 12, 'zorder': 9, 'markeredgewidth': 1.0
    },
    'Your Bound (QP Analytical)': {
        'color': 'dimgrey', 'marker': '^', 'linestyle': ':', 'label': 'Bound (QP Approx)',
        'lw': 3.0, 'markersize': 12, 'zorder': 8, 'markeredgewidth': 1.0
    },
    'Your Bound (Binary)': {
        'color': 'darkgrey', 'marker': 's', 'linestyle': '--', 'label': 'Bound (Binary)',
        'lw': 3.0, 'markersize': 12, 'zorder': 7, 'markeredgewidth': 1.0
    },
    'Bound (Leverage Score Exp.)': {
        'color': 'deepskyblue', 'marker': 'D', 'linestyle': '-.', 'label': 'Bound (Lev. Score Exp.)',
        'lw': 3.0, 'markersize': 12, 'zorder': 6, 'markeredgewidth': 1.0, 'markeredgecolor': 'navy'
    },
    'Bound (Sketching Simple)': {
        'color': 'sandybrown', 'marker': 'P', 'linestyle': ':', 'label': 'Bound (Sketching Simple)',
        'lw': 3.0, 'markersize': 12, 'zorder': 5, 'markeredgewidth': 1.0, 'markeredgecolor': 'saddlebrown'
    },
    'Error Leverage Score (Actual)': { # RMM Sampling
        'color': 'blue', 'marker': 'x', 'linestyle': '-', 'label': 'Leverage Score Sampling',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 4, 'markeredgewidth': 2.0
    },
    'Error CountSketch (Actual)': {
        'color': 'orange', 'marker': 'd', 'linestyle': '--', 'label': 'CountSketch',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 3, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkorange'
    },
    'Error SRHT (Actual)': {
        'color': 'red', 'marker': 'v', 'linestyle': '-.', 'label': 'SRHT',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 2, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkred'
    },
    'Error Gaussian (Actual)': {
        'color': 'darkviolet', 'marker': '<', 'linestyle': ':', 'label': 'Gaussian Proj.',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 1, 'markeredgewidth': 1.5, 'markeredgecolor': 'indigo'
    },
    'Error Greedy OMP (Actual)': {
        'color': 'forestgreen', 'marker': '>', 'linestyle': '-', 'label': 'Greedy OMP',
        'lw': 3.0, 'markersize': 14, 'alpha': 0.9, 'zorder': 0, 'markeredgewidth': 1.5, 'markeredgecolor': 'darkgreen'
    },
    'Uniform': {'color': 'crimson', 'marker': 'X', 'linestyle': '-', 'lw': 3.0, 'markersize': 12, 'label': r'Uniform Sampling', 'zorder': 3.5, 'alpha': 0.9, 'markeredgewidth': 1.5},
    'Deterministic': {'color': 'chocolate', 'marker': 'p', 'linestyle': ':', 'lw': 3.0, 'markersize': 12, 'label': r'Deterministic Sampling', 'zorder': 3.3, 'alpha': 0.9, 'markeredgewidth': 1.5},
    'Our Bound (v_k Ratio)': {
        'color': '#555555', 'marker': 'h', 'linestyle': (0, (3,2,1,2)), 'label': r'Our Bound ($v_k$ Ratio)',
        'lw': 3.0, 'markersize': 12, 'zorder': 7.5, 'markeredgewidth': 1.0
        },
}
# Style used for a series name that has no entry in IMPROVED_STYLES; the low
# z-order keeps such lines behind the styled ones.
FALLBACK_STYLE = {'color': 'purple', 'marker': '.', 'linestyle': '-', 'lw': 1.5, 'markersize': 6, 'label': 'Unknown Data', 'zorder': -1}

# ==============================================================================
# Utility Functions
# ==============================================================================
def sanitize_filename(name: str) -> str:
    """Return a filesystem-friendly version of *name*.

    Spaces become underscores, every character that is not alphanumeric or one
    of ``_``, ``-``, ``.`` is dropped, runs of underscores / hyphens are
    collapsed to a single character, and leading or trailing underscores and
    hyphens are removed.

    Args:
        name: Arbitrary label, e.g. a plot or matrix-type title.

    Returns:
        The sanitized string (possibly empty).
    """
    # Spaces must be converted *before* filtering: the filter removes them, so
    # converting afterwards would glue neighbouring words together.
    underscored = name.replace(' ', '_')
    kept = "".join(c for c in underscored if c.isalnum() or c in ('_', '-', '.'))
    # split/filter/join collapses repeated separators and trims them at the ends.
    kept = '_'.join(filter(None, kept.split('_')))
    kept = '-'.join(filter(None, kept.split('-')))
    return kept

def safe_norm(matrix: Union[np.ndarray, scipy.sparse.spmatrix], norm_type: str = 'fro') -> float:
    """Compute a matrix norm without letting failures propagate.

    Dense inputs go straight to ``np.linalg.norm``. Sparse inputs use the
    sparse Frobenius routine when ``norm_type`` is ``'fro'``; any other norm is
    computed on a dense copy. Empty matrices yield ``0.0`` (for sparse inputs
    only on the non-Frobenius path).

    Args:
        matrix: Dense array or scipy sparse matrix.
        norm_type: Value forwarded as ``ord``.

    Returns:
        The norm, or ``np.nan`` (after a UserWarning) if anything goes wrong.
    """
    try:
        if not scipy.sparse.issparse(matrix):
            if matrix.size == 0:
                return 0.0
            return np.linalg.norm(matrix, ord=norm_type)

        if norm_type == 'fro':
            return scipy.sparse.linalg.norm(matrix, ord='fro')
        n_rows, n_cols = matrix.shape[0], matrix.shape[1]
        if n_rows * n_cols == 0:
            return 0.0
        # Other norms are not taken on the sparse object; densify first.
        try:
            as_dense = matrix.toarray()
        except MemoryError:
            warnings.warn("MemoryError converting sparse to dense for norm. NaN.", UserWarning)
            return np.nan
        return np.linalg.norm(as_dense, ord=norm_type)
    except MemoryError:
        warnings.warn("MemoryError in norm. NaN.", UserWarning)
        return np.nan
    except Exception as e:
        warnings.warn(f"Error in norm: {e}. NaN.", UserWarning)
        return np.nan

def frob_norm_sq(matrix: np.ndarray) -> float:
    """Return the squared Frobenius norm of *matrix*, computed via ``safe_norm``."""
    frobenius = safe_norm(matrix, 'fro')
    return frobenius**2

def safe_svd(matrix: Union[np.ndarray, scipy.sparse.spmatrix], k: Optional[int] = None) -> Tuple[Optional[np.ndarray], Optional[np.ndarray], Optional[np.ndarray]]:
    """Compute an SVD, returning ``(None, None, None)`` instead of raising.

    Dense input: full thin SVD via ``np.linalg.svd`` (``k`` is not used).
    Sparse input: truncated SVD via ``svds`` with at most ``min(shape) - 1``
    components; the factors are reversed before being returned.

    Args:
        matrix: Dense array or scipy sparse matrix.
        k: Requested number of singular triplets for the sparse path.

    Returns:
        ``(U, s, Vt)`` on success, ``(None, None, None)`` after a UserWarning
        on any failure.
    """
    failure = (None, None, None)
    try:
        if not scipy.sparse.issparse(matrix):
            U, s, Vt = np.linalg.svd(matrix, full_matrices=False)
            return U, s, Vt

        rank_cap = min(matrix.shape) - 1
        eff_k = rank_cap if k is None else min(k, rank_cap)
        if eff_k <= 0:
            warnings.warn(f"Sparse SVD k={eff_k} invalid. None.", UserWarning)
            return failure
        try:
            U, s, Vt = svds(matrix, k=eff_k)
            # Flip the component order of all three factors together.
            U = U[:, ::-1]
            s = s[::-1]
            Vt = Vt[::-1, :]
            return U, s, Vt
        except Exception as e:
            warnings.warn(f"Sparse SVD failed: {e}. None.", UserWarning)
            return failure
    except np.linalg.LinAlgError as e:
        warnings.warn(f"SVD failed: {e}. None.", UserWarning)
        return failure
    except MemoryError:
        warnings.warn("MemoryError in SVD. None.", UserWarning)
        return failure
    except Exception as e:
        warnings.warn(f"Unexpected SVD error: {e}. None.", UserWarning)
        return failure

# ==============================================================================
# Rho_G Calculation
# ==============================================================================
def calculate_rho_g(A: np.ndarray, B: np.ndarray) -> float:
    """Return trace(G) / sum(G) for G = (A^T A) * (B^T B) (element-wise product).

    Args:
        A: Matrix whose columns are paired with the columns of ``B``.
        B: Matrix with the same number of columns as ``A``.

    Returns:
        The ratio clipped below at 0. Special cases: ``0.0`` when there are no
        columns or G is numerically zero, ``np.inf`` when the sum of G vanishes
        but its trace does not, and ``np.nan`` (after a warning) on any error,
        including a column-count mismatch.
    """
    tiny = 1e-12
    try:
        if A.shape[1] != B.shape[1]:
            raise ValueError(f"Dim mismatch Rho_G: A {A.shape[1]}, B {B.shape[1]} cols.")
        A_f64 = np.asarray(A, dtype=np.float64)
        B_f64 = np.asarray(B, dtype=np.float64)
        if A_f64.shape[1] == 0:
            return 0.0

        gram_A = A_f64.T @ A_f64
        gram_B = B_f64.T @ B_f64
        G_hadamard = gram_A * gram_B
        trace_G = np.trace(G_hadamard)
        sum_G = np.sum(G_hadamard)

        if abs(sum_G) > tiny:
            rho = trace_G / sum_G
            return max(0, rho)  # Ensure non-negative

        # Degenerate denominator: distinguish "G is zero" from "entries cancel".
        if np.linalg.norm(G_hadamard, 'fro') < tiny:
            return 0.0
        if abs(trace_G) > tiny:
            return np.inf
        return 0.0
    except Exception as e:
        warnings.warn(f"Error in Rho_G (User): {e}")
        return np.nan

# ==============================================================================
# Matrix Generation Functions (Selected 5 + 1 variant = 6 entries in dict)
# ==============================================================================
def generate_matrices_gaussian_cancellation(m_rows_A: int, p_rows_B: int, n_cols_common: int, cancel_fraction: float = 0.0, noise_level: float = 0.0, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Draw standard-normal A and B, optionally forcing column cancellation.

    For a random subset of ``int(cancel_fraction * n_cols_common)`` columns,
    the rows that A and B have in common are set so that ``B = -A`` there.
    Gaussian noise with standard deviation ``noise_level`` is then added to
    both matrices when ``noise_level > 0``.

    Args:
        m_rows_A: Number of rows of A.
        p_rows_B: Number of rows of B.
        n_cols_common: Shared number of columns.
        cancel_fraction: Fraction of columns to cancel.
        noise_level: Standard deviation of the additive noise.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        ``(A, B)`` with shapes ``(m_rows_A, n_cols_common)`` and
        ``(p_rows_B, n_cols_common)``.
    """
    if seed is not None:
        np.random.seed(seed)
    A = np.random.randn(m_rows_A, n_cols_common)
    B = np.random.randn(p_rows_B, n_cols_common)

    n_cancel = int(cancel_fraction * n_cols_common)
    if n_cancel > 0 and n_cols_common > 1:
        cols = np.random.choice(n_cols_common, n_cancel, replace=False)
        # Only the leading rows present in both matrices can be mirrored; when
        # the row counts are equal this covers every row.
        shared_rows = min(m_rows_A, p_rows_B)
        B[:shared_rows, cols] = -A[:shared_rows, cols]

    if noise_level > 0:
        A += np.random.normal(0, noise_level, size=A.shape)
        B += np.random.normal(0, noise_level, size=B.shape)
    return A, B

def generate_matrices_uniform(m_rows_A: int, p_rows_B: int, n_cols_common: int, low: float = -1.0, high: float = 1.0, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Draw A and B with i.i.d. entries from ``np.random.uniform(low, high)``.

    Args:
        m_rows_A: Number of rows of A.
        p_rows_B: Number of rows of B.
        n_cols_common: Shared number of columns.
        low: Lower bound passed to the uniform sampler.
        high: Upper bound passed to the uniform sampler.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        ``(A, B)``; A is drawn before B.
    """
    if seed is not None:
        np.random.seed(seed)
    shape_A = (m_rows_A, n_cols_common)
    shape_B = (p_rows_B, n_cols_common)
    A = np.random.uniform(low, high, size=shape_A)
    B = np.random.uniform(low, high, size=shape_B)
    return A, B

def generate_matrices_repeated_cols(m_rows_A: int, p_rows_B: int, n_cols_common: int, repeat_frac: float = 0.1, noise_ratio: float = 0.01, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Draw Gaussian A and B in which some columns are copies of others.

    ``max(1, int(n_cols_common * (1 - repeat_frac)))`` randomly chosen columns
    stay unique; every remaining column is overwritten (in both A and B, using
    the same source index) with a randomly chosen unique column. When
    ``noise_ratio > 0``, Gaussian noise rescaled to ``noise_ratio`` times the
    Frobenius norm of each matrix is added.

    Args:
        m_rows_A: Number of rows of A.
        p_rows_B: Number of rows of B.
        n_cols_common: Shared number of columns.
        repeat_frac: Fraction of columns that should be repeats.
        noise_ratio: Relative magnitude of the additive noise.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        ``(A, B)``.
    """
    if seed is not None:
        np.random.seed(seed)
    A_base = np.random.randn(m_rows_A, n_cols_common)
    B_base = np.random.randn(p_rows_B, n_cols_common)
    A = A_base.copy()
    B = B_base.copy()

    # At least one column is always kept unique, so a positive repeat count
    # already implies 0 < num_unique_cols < n_cols_common.
    num_unique_cols = max(1, int(n_cols_common * (1 - repeat_frac)))
    num_repeat_cols = n_cols_common - num_unique_cols
    if num_repeat_cols > 0:
        every_col = np.arange(n_cols_common)
        kept_cols = np.random.choice(every_col, num_unique_cols, replace=False)
        overwritten_cols = np.setdiff1d(every_col, kept_cols)
        donor_cols = np.random.choice(kept_cols, len(overwritten_cols), replace=True)
        A[:, overwritten_cols] = A_base[:, donor_cols]
        B[:, overwritten_cols] = B_base[:, donor_cols]

    if noise_ratio > 0:
        noise_A = np.random.randn(*A.shape)
        noise_B = np.random.randn(*B.shape)
        # The 1e-9 guards against dividing by a zero noise norm.
        scale_A = safe_norm(A) / (safe_norm(noise_A) + 1e-9)
        A += noise_A * scale_A * noise_ratio
        scale_B = safe_norm(B) / (safe_norm(noise_B) + 1e-9)
        B += noise_B * scale_B * noise_ratio
    return A, B

def generate_matrices_nonlinear(m_rows_A: int, p_rows_B: int, n_cols_common: int, base_type: str = 'gaussian', func=np.tanh, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Generate base matrices and push them through an element-wise function.

    Args:
        m_rows_A: Number of rows of A.
        p_rows_B: Number of rows of B.
        n_cols_common: Shared number of columns.
        base_type: ``'gaussian'`` selects the Gaussian generator; any other
            value selects the uniform generator. Both run with their defaults.
        func: Callable applied to each base matrix (default ``np.tanh``).
        seed: Forwarded to the base generator.

    Returns:
        ``(func(A_base), func(B_base))``.
    """
    if base_type == 'gaussian':
        base_generator = generate_matrices_gaussian_cancellation
    else:
        base_generator = generate_matrices_uniform
    # Only the shapes and the seed are forwarded; generator-specific options
    # keep their default values.
    raw_A, raw_B = base_generator(m_rows_A, p_rows_B, n_cols_common, seed=seed)
    A = func(raw_A)
    B = func(raw_B)
    return A, B

def generate_matrices_row_orthogonal(m_rows_A: int, p_rows_B: int, n_cols_common: int, seed: Optional[int] = None) -> Tuple[np.ndarray, np.ndarray]:
    """Build A and B whose rows are orthonormal.

    Each matrix is the transpose of the Q factor of a QR decomposition of a
    random Gaussian ``(n_cols_common, rows)`` matrix. If there are fewer
    columns than rows in either matrix, a warning is issued and plain Gaussian
    matrices are returned instead.

    Args:
        m_rows_A: Number of rows of A.
        p_rows_B: Number of rows of B.
        n_cols_common: Shared number of columns.
        seed: If given, A is drawn after seeding with ``seed`` and B after
            seeding with ``seed + 1``.

    Returns:
        ``(A, B)``.
    """
    enough_columns = n_cols_common >= m_rows_A and n_cols_common >= p_rows_B
    if not enough_columns:
        warnings.warn(f"Row orthogonal needs n_cols >= m_rows, p_rows. Gaussian fallback.", UserWarning)
        return generate_matrices_gaussian_cancellation(m_rows_A, p_rows_B, n_cols_common, 0.0, 0.0, seed=seed)

    if seed is not None:
        np.random.seed(seed)
    gaussian_for_A = np.random.randn(n_cols_common, m_rows_A)
    Q_A, _ = np.linalg.qr(gaussian_for_A)
    A = Q_A.T

    # Re-seed with a different value so B is not built from A's random draw.
    if seed is not None:
        np.random.seed(seed + 1)
    gaussian_for_B = np.random.randn(n_cols_common, p_rows_B)
    Q_B, _ = np.linalg.qr(gaussian_for_B)
    B = Q_B.T
    return A, B

# ==============================================================================
# Core Sampling Algorithms (Standard)
# ==============================================================================
def uniform_sampling(A, B, k, seed=None):
    """Sample k column pairs of A and B uniformly without replacement.

    The selected columns of both matrices are multiplied by ``sqrt(n / k)``,
    where n is the number of columns.

    Args:
        A: Dense array or scipy sparse matrix with n columns.
        B: Dense array or scipy sparse matrix with n columns.
        k: Requested number of columns; clamped to the range ``[0, n]``.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        ``(C, W, idx)``: the rescaled column subsets of A and B and the chosen
        column indices. When the clamped k is 0 (k <= 0 or no columns at all),
        the subsets have zero columns and ``idx`` is an empty int array.
    """
    if seed is not None:
        np.random.seed(seed)
    n = A.shape[1]
    # Clamp to [0, n]. A lower bound of 1 would make the empty-selection branch
    # below unreachable, silently turn k=0 into k=1, and make
    # np.random.choice fail when there are no columns to choose from.
    k = max(0, min(k, n))
    if k == 0:
        return A[:, :0], B[:, :0], np.array([], dtype=int)

    idx = np.random.choice(n, k, replace=False)
    s = np.sqrt(n / k)
    C = A[:, idx].multiply(s) if scipy.sparse.issparse(A) else A[:, idx] * s
    W = B[:, idx].multiply(s) if scipy.sparse.issparse(B) else B[:, idx] * s
    return C, W, idx

def rmm_sampling(A, B, k, seed=None): # Leverage Score Sampling
    """Sample k column pairs with probability proportional to norm products.

    Column i is drawn with probability ``p_i`` proportional to
    ``||A[:, i]||^2 * ||B[:, i]||^2`` (without replacement) and the selected
    columns of both matrices are scaled by ``1 / sqrt(k * p_i)``. Falls back
    to ``uniform_sampling`` when all weights are numerically zero or when
    ``np.random.choice`` rejects the request.

    NOTE(review): the ``1 / sqrt(k * p_i)`` rescaling is the one usually paired
    with sampling *with* replacement, while the draw here uses
    ``replace=False`` — confirm this combination is intended.

    Args:
        A: Dense array or scipy sparse matrix with n columns.
        B: Dense array or scipy sparse matrix with n columns.
        k: Requested number of columns; clamped to the range ``[0, n]``.
        seed: If given, seeds NumPy's global random generator first.

    Returns:
        ``(C, W, idx)``: rescaled column subsets and the chosen indices. When
        the clamped k is 0, the subsets have zero columns and ``idx`` is empty.
    """
    if seed is not None:
        np.random.seed(seed)
    n = A.shape[1]
    # Clamp to [0, n]; a lower bound of 1 would make the empty-selection
    # branch unreachable and silently turn k=0 into k=1.
    k = max(0, min(k, n))
    if k == 0:
        return A[:, :0], B[:, :0], np.array([], dtype=int)

    # Squared column norms of each matrix.
    nA_sq = np.array(A.power(2).sum(axis=0)).flatten() if scipy.sparse.issparse(A) else np.sum(A**2, axis=0)
    nB_sq = np.array(B.power(2).sum(axis=0)).flatten() if scipy.sparse.issparse(B) else np.sum(B**2, axis=0)
    g_sq = nA_sq * nB_sq
    sum_g_sq = np.sum(g_sq)

    # No usable weights: every column is (numerically) zero in A or B.
    if sum_g_sq < 1e-20:
        return uniform_sampling(A, B, k, seed)
    p = np.maximum(g_sq / sum_g_sq, 0)
    p /= np.sum(p)

    try:
        idx = np.random.choice(n, k, replace=False, p=p)
    except ValueError:
        # E.g. fewer columns with non-zero probability than k.
        return uniform_sampling(A, B, k, seed)

    # The floor inside the sqrt avoids division by zero for tiny probabilities.
    s_vals = 1.0 / np.sqrt(np.maximum(k * p[idx], 1e-20))
    C = A[:, idx] @ scipy.sparse.diags(s_vals) if scipy.sparse.issparse(A) else A[:, idx] * s_vals
    W = B[:, idx] @ scipy.sparse.diags(s_vals) if scipy.sparse.issparse(B) else B[:, idx] * s_vals
    return C, W, idx

def deterministic_sampling(A, B, k):
    """Select the k column pairs with the largest norm products.

    Columns are ranked by ``||A[:, i]||^2 * ||B[:, i]||^2`` and the top k are
    returned unscaled.

    Args:
        A: Dense array or scipy sparse matrix with n columns.
        B: Dense array or scipy sparse matrix with n columns.
        k: Requested number of columns; clamped to the range ``[0, n]``.

    Returns:
        ``(C, W, idx)``: the selected columns of A and B and their indices.
        When the clamped k is 0, the subsets have zero columns and ``idx`` is
        an empty int array.
    """
    n = A.shape[1]
    # Clamp to [0, n]; a lower bound of 1 would make the empty-selection
    # branch unreachable and silently turn k=0 into k=1.
    k = max(0, min(k, n))
    if k == 0:
        return A[:, :0], B[:, :0], np.array([], dtype=int)

    # Squared column norms of each matrix.
    nA_sq = np.array(A.power(2).sum(axis=0)).flatten() if scipy.sparse.issparse(A) else np.sum(A**2, axis=0)
    nB_sq = np.array(B.power(2).sum(axis=0)).flatten() if scipy.sparse.issparse(B) else np.sum(B**2, axis=0)
    g_sq = nA_sq * nB_sq

    if k >= n:
        idx = np.arange(n)
    elif k < n / 2:
        # Partial selection is enough when only a small part is requested.
        idx = np.argpartition(g_sq, -k)[-k:]
    else:
        idx = np.argsort(g_sq)[-k:]

    C = A[:, idx]
    W = B[:, idx]
    return C, W, idx

# ==============================================================================
# USER-PROVIDED ACCURATE ALGORITHMS
# ==============================================================================
def run_greedy_selection_omp(A: np.ndarray, B: np.ndarray, k: int, ABt_exact: Optional[np.ndarray] = None) -> np.ndarray:
    """Approximate ``A @ B.T`` from k greedily selected column pairs.

    In each round the not-yet-selected column pair whose outer product has the
    largest absolute inner product with the current residual is added, and the
    residual is recomputed as ``ABt_exact`` minus the product of the selected
    columns. Ties go to the lowest remaining column index.

    Args:
        A: Matrix of shape (m, n).
        B: Matrix of shape (p, n).
        k: Number of column pairs to select, with ``0 <= k <= n``.
        ABt_exact: Optional precomputed ``A @ B.T``.

    Returns:
        The (m, p) product of the selected columns; all zeros when nothing is
        selected.

    Raises:
        ValueError: If the column counts differ or k is outside ``[0, n]``.
    """
    m, n = A.shape
    p, n_B = B.shape
    if n != n_B:
        raise ValueError("Dimension mismatch for OMP")
    if not (0 <= k <= n):
        raise ValueError(f"k={k} must be between 0 and n={n} for Greedy OMP")
    if k == 0 or n == 0:
        return np.zeros((m, p), dtype=A.dtype)

    if ABt_exact is None:
        ABt_exact = A @ B.T

    # The selection itself runs in float64.
    residual = ABt_exact.astype(np.float64).copy()
    A64 = A.astype(np.float64)
    B64 = B.astype(np.float64)

    chosen = []
    candidates = list(range(n))
    for _ in range(k):
        if not candidates:
            break
        top_score, top_pos = -1.0, -1
        for pos, col in enumerate(candidates):
            rank_one = np.outer(A64[:, col], B64[:, col])
            score = np.abs(np.sum(residual * rank_one))
            # Strict comparison keeps the earliest candidate on ties.
            if score > top_score:
                top_score, top_pos = score, pos
        if top_pos == -1:
            break
        chosen.append(candidates.pop(top_pos))
        residual = ABt_exact - A64[:, chosen] @ B64[:, chosen].T

    if not chosen:
        return np.zeros((m, p), dtype=A.dtype)
    # The returned product is built from the inputs in their original dtype.
    return A[:, chosen] @ B[:, chosen].T

def run_gaussian_projection(A: np.ndarray, B: np.ndarray, k: int) -> np.ndarray:
    """Approximate ``A @ B.T`` through a random Gaussian sketch of size k.

    A ``(k, n)`` matrix S with standard-normal entries divided by ``sqrt(k)``
    is drawn from NumPy's global random generator, and
    ``(A @ S.T) @ (B @ S.T).T`` is returned.

    Args:
        A: Matrix of shape (m, n).
        B: Matrix of shape (p, n).
        k: Sketch size.

    Returns:
        The (m, p) approximation; all zeros when ``k <= 0`` or ``n == 0``.

    Raises:
        ValueError: If the column counts of A and B differ.
    """
    m, n = A.shape
    p, n_B = B.shape
    if n != n_B:
        raise ValueError("Dimension mismatch for Gaussian Projection")
    if k <= 0 or n == 0:
        return np.zeros((m, p), dtype=A.dtype)

    sketch = np.random.randn(k, n) / np.sqrt(k)
    sketched_A = A @ sketch.T
    sketched_B = B @ sketch.T
    return sketched_A @ sketched_B.T

# ==============================================================================
# USER-PROVIDED AND MODIFIED THEORETICAL BOUNDS CALCULATION
# ==============================================================================
def compute_theoretical_bounds(data: Dict[str, Any], k: int) -> Dict[str, Any]:
    """Evaluate the custom error-bound ratios for a single value of k.

    Args:
        data: Precomputed quantities. Always read: ``'n'`` (number of common
            columns) and ``'frob_norm'`` (used as the normaliser of the
            ratios). The QP/binary bounds are only computed when all of
            ``'Gab'``, ``'q'``, ``'r'`` and ``'trace'`` are present; the greedy
            bound only when both ``'A'`` and ``'B'`` are present, with
            ``A.shape[1] == n`` and ``B.shape[0] == n``.
        k: Number of selected columns.

    Returns:
        Dict with keys ``'binary_ratio'``, ``'qp_ratio'``, ``'qp_ratio_best'``
        and ``'Greedy_ratio'``. A value stays ``np.nan`` whenever its
        prerequisites are missing or its computation is skipped.
    """
    n = data['n']; frob_norm_AB = data['frob_norm']
    # Every bound defaults to NaN ("not computed").
    binary_ratio, qp_ratio, vk_ratio, Greedy_bound_ratio = np.nan, np.nan, np.nan, np.nan
    has_qp_data = all(key in data for key in ['Gab', 'q', 'r', 'trace'])

    if has_qp_data:
        # NOTE: r_val is read here but not used anywhere below.
        G = data['Gab']; q_vec = data['q']; r_val = data['r']; TrG = data['trace']
        # Squared normaliser; presumably equals the sum of all entries of G
        # (hence the name) — verify against the caller.
        oneGone = frob_norm_AB**2
        if n > 1 and oneGone > 1e-12:
            rho_G = TrG / oneGone if oneGone > 1e-12 else 0.0
            beta_k = (k - 1) / (n - 1) if n > 1 else 0.0
            # Closed-form ("QP Analytical") bound: sqrt(max(0, 1 - k*gamma/n))
            # with gamma = 1 / (beta_k + (1 - beta_k) * rho_G). Skipped (left
            # NaN) when the denominator is numerically zero.
            denominator = (beta_k + (1 - beta_k) * rho_G)
            if abs(denominator) > 1e-12:
                gamma = 1.0 / denominator
                qp_bound_sq_ratio = max(0, 1.0 - k * gamma / n)
                qp_ratio = np.sqrt(qp_bound_sq_ratio)
            # "Binary" bound, expressed relative to frob_norm_AB.
            alpha_k = k / (n - 1) if n > 1 else 0.0
            binary_bound_sq = max(0, (1.0 - k * 1.0 / n) * ((1.0 - alpha_k) * oneGone + alpha_k * TrG))
            binary_ratio = np.sqrt(binary_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0
            if k > 0 and cvxpy_present:
                # Numerical ("QP CVXPY Best") bound: minimise
                # 0.5 * y^T G_hat_k y - q^T y subject to y >= 0, where G_hat_k
                # blends G with its diagonal using weight beta_k.
                G_hat_k = beta_k * G + (1 - beta_k) * np.diag(np.diag(G))
                y = cp.Variable(n); constraints = [y >= 0]
                objective = cp.Minimize(0.5 * cp.quad_form(y, G_hat_k) - q_vec.T @ y)
                prob = cp.Problem(objective, constraints)
                try:
                    prob.solve(solver=cp.SCS, verbose=False, eps=1e-7, max_iters=10000)
                    # Only (near-)optimal solutions are turned into a bound.
                    if prob.status in [cp.OPTIMAL, cp.OPTIMAL_INACCURATE]:
                        v_k_bound_sq = max(0, oneGone + (k / n) * 2.0 * prob.value)
                        vk_ratio = np.sqrt(v_k_bound_sq) / frob_norm_AB if frob_norm_AB > 1e-12 else 0.0
                # Any failure while solving is swallowed silently and leaves
                # vk_ratio at NaN (Exception already covers cp.SolverError).
                except (cp.SolverError, Exception): pass
            elif k > 0 and not cvxpy_present: vk_ratio = np.nan
            elif k == 0: vk_ratio = 1.0 if oneGone > 1e-12 else 0.0
    if 'A' in data and 'B' in data: # B is n_common x p
        A_data, B_data_transposed = data['A'], data['B']
        # The greedy bound is only evaluated when both shapes agree with n.
        if A_data.shape[1] == n and B_data_transposed.shape[0] == n:
            if not 0 <= k <= n: Greedy_bound_ratio = np.nan
            elif k == 0: Greedy_bound_ratio = 1.0 if frob_norm_AB > 1e-12 else 0.0
            elif k == n: Greedy_bound_ratio = 0.0
            else:
                # T_vals[i] = ||A[:, i]||^2 * ||B_data_transposed[i, :]||^2.
                norms_A_sq = np.sum(A_data * A_data, axis=0)
                norms_B_sq = np.sum(B_data_transposed * B_data_transposed, axis=1) # B_data_transposed is n x p, sum over p
                T_vals = norms_A_sq * norms_B_sq
                # Sum of the n-k smallest products, i.e. those left out when
                # the k largest are kept.
                sum_T_complement = np.sum(np.sort(T_vals)[:n-k]) if n-k > 0 else 0.0
                Greedy_bound_val_sq = max(0, sum_T_complement)
                Greedy_bound_val = np.sqrt(Greedy_bound_val_sq)
                Greedy_bound_ratio = Greedy_bound_val / frob_norm_AB if frob_norm_AB > 1e-12 else (0.0 if Greedy_bound_val < 1e-12 else np.inf)
    return {'binary_ratio': binary_ratio, 'qp_ratio': qp_ratio, 'qp_ratio_best': vk_ratio, 'Greedy_ratio': Greedy_bound_ratio}

def compute_standard_bounds(A: np.ndarray, B: np.ndarray, k: int, frob_ABt_sq: float) -> Dict[str, float]:
    """Evaluate two reference bounds, relative to ``sqrt(frob_ABt_sq)``.

    * ``'Bound (Leverage Score Exp.)'``:
      ``sqrt(max(0, ((sum_i ||A_i|| ||B_i||)^2 - frob_ABt_sq) / (k * frob_ABt_sq)))``.
    * ``'Bound (Sketching Simple)'``:
      ``sqrt(||A||_F^2 ||B||_F^2 / (k * frob_ABt_sq))``.

    Args:
        A: Matrix of shape (m, n).
        B: Matrix of shape (p, n).
        k: Number of sampled columns / sketch size.
        frob_ABt_sq: Squared normaliser for both bounds.

    Returns:
        Dict with the two keys above. Both are ``np.nan`` for mismatched or
        empty column dimensions or a near-zero ``frob_ABt_sq``; for ``k == 0``
        with a usable ``frob_ABt_sq`` they are ``1.0`` and ``np.inf``. A bound
        whose computation raises stays ``np.nan`` (a RuntimeWarning is issued).
    """
    LEV_KEY = 'Bound (Leverage Score Exp.)'
    SKETCH_KEY = 'Bound (Sketching Simple)'

    m, n_common_A = A.shape
    p, n_common_B = B.shape  # B is P x n_common
    bound_leverage_exp_ratio = np.nan
    bound_sketching_simple_ratio = np.nan

    unusable_input = n_common_A != n_common_B or n_common_A == 0 or frob_ABt_sq < 1e-20
    if unusable_input:
        if k == 0 and frob_ABt_sq > 1e-20:
            bound_leverage_exp_ratio = 1.0
            bound_sketching_simple_ratio = np.inf
        return {LEV_KEY: bound_leverage_exp_ratio, SKETCH_KEY: bound_sketching_simple_ratio}
    if k == 0:
        return {LEV_KEY: 1.0, SKETCH_KEY: np.inf}

    A_f64 = A.astype(np.float64)
    B_f64 = B.astype(np.float64)

    # Bound from the sum of products of column norms.
    try:
        norms_A = np.linalg.norm(A_f64, axis=0)
        norms_B = np.linalg.norm(B_f64, axis=0)
        sum_prod_norms = max(0, np.sum(norms_A * norms_B))
        expected_error_sq_abs = (sum_prod_norms**2 - frob_ABt_sq) / k
        if frob_ABt_sq > 1e-20:
            bound_leverage_exp_sq_ratio = max(0, expected_error_sq_abs / frob_ABt_sq)
        else:
            bound_leverage_exp_sq_ratio = 0.0
        if pd.notna(bound_leverage_exp_sq_ratio) and bound_leverage_exp_sq_ratio >= 0:
            bound_leverage_exp_ratio = np.sqrt(bound_leverage_exp_sq_ratio)
    except Exception as e:
        warnings.warn(f"Lev. Score bound k={k}: {e}", RuntimeWarning)

    # Bound from the product of the two Frobenius norms.
    try:
        frob_A_sq = frob_norm_sq(A_f64)
        frob_B_sq = frob_norm_sq(B_f64)
        if k * frob_ABt_sq > 1e-20:
            sketching_bound_sq_ratio = max(0, (frob_A_sq * frob_B_sq) / (k * frob_ABt_sq))
        else:
            sketching_bound_sq_ratio = np.inf
        if pd.notna(sketching_bound_sq_ratio) and sketching_bound_sq_ratio >= 0:
            bound_sketching_simple_ratio = np.sqrt(sketching_bound_sq_ratio)
    except Exception as e:
        warnings.warn(f"Sketching bound k={k}: {e}", RuntimeWarning)

    return {LEV_KEY: bound_leverage_exp_ratio, SKETCH_KEY: bound_sketching_simple_ratio}

def compute_all_bounds_orchestrator(A_orig_mat, B_orig_mat, k_values_list, frob_ABt_sq_val, common_dim_n) -> Dict[str, np.ndarray]:
    """Evaluate every bound for each k and collect the ratios per bound name.

    Args:
        A_orig_mat: Dense array or scipy sparse matrix.
        B_orig_mat: Dense array or scipy sparse matrix with the same number of
            columns as ``A_orig_mat``.
        k_values_list: The k values to evaluate.
        frob_ABt_sq_val: Squared normaliser passed on to the bound functions.
        common_dim_n: Number of common columns, stored as ``'n'`` for the
            custom bounds.

    Returns:
        Dict mapping each bound name to an array with one entry per k (NaN
        where a bound is unavailable), plus ``'k_values_bounds'`` holding the
        k values. If ``frob_ABt_sq_val`` is NaN or below 1e-20 a warning is
        issued and all bounds are left as NaN.
    """
    BINARY = 'Your Bound (Binary)'
    QP_ANALYTICAL = 'Your Bound (QP Analytical)'
    QP_BEST = 'Your Bound (QP CVXPY Best)'
    LEVERAGE = 'Bound (Leverage Score Exp.)'
    SKETCHING = 'Bound (Sketching Simple)'

    num_k = len(k_values_list)
    results = {name: np.full(num_k, np.nan) for name in (BINARY, QP_ANALYTICAL, QP_BEST, LEVERAGE, SKETCHING)}
    results['k_values_bounds'] = np.array(k_values_list)

    if pd.isna(frob_ABt_sq_val) or frob_ABt_sq_val < 1e-20:
        warnings.warn("Skipping bound calculations due to near-zero ||AB^T||_F^2.", UserWarning)
        return results

    if scipy.sparse.issparse(A_orig_mat):
        A_dense = A_orig_mat.toarray()
    else:
        A_dense = A_orig_mat.astype(np.float64)
    if scipy.sparse.issparse(B_orig_mat):
        B_dense = B_orig_mat.toarray()
    else:
        B_dense = B_orig_mat.astype(np.float64)  # B_dense is P x n_common

    # Quantities shared by all k for the custom bounds.
    data_for_custom_bounds = {'n': common_dim_n, 'frob_norm': np.sqrt(frob_ABt_sq_val)}
    custom_bounds_prepped = False
    try:
        gram_A = A_dense.T @ A_dense
        gram_B = B_dense.T @ B_dense  # n_common x n_common
        Gab = gram_A * gram_B
        data_for_custom_bounds['Gab'] = Gab
        data_for_custom_bounds['q'] = np.diag(Gab)
        data_for_custom_bounds['r'] = np.sum(Gab)
        data_for_custom_bounds['trace'] = np.trace(Gab)
        data_for_custom_bounds['A'] = A_dense
        data_for_custom_bounds['B'] = B_dense.T  # For Greedy_ratio: B is n_common x P
        custom_bounds_prepped = True
    except Exception as e:
        warnings.warn(f"Failed to precompute Gab for custom bounds: {e}", UserWarning)

    custom_sources = ((BINARY, 'binary_ratio'), (QP_ANALYTICAL, 'qp_ratio'), (QP_BEST, 'qp_ratio_best'))
    for position, k_iter in enumerate(k_values_list):
        if custom_bounds_prepped:
            user_b_k = compute_theoretical_bounds(data_for_custom_bounds, k_iter)
            for result_name, source_name in custom_sources:
                results[result_name][position] = user_b_k.get(source_name, np.nan)
        standard_b_k = compute_standard_bounds(A_dense, B_dense, k_iter, frob_ABt_sq_val)
        for result_name in (LEVERAGE, SKETCHING):
            results[result_name][position] = standard_b_k.get(result_name, np.nan)
    return results

# ==============================================================================
# Experiment Runner
# ==============================================================================
def run_algorithm_experiments(A_orig, B_orig, k_values_list, n_trials_per_k=5, main_seed_exp_runner=42):
    """Measure the relative Frobenius error of every algorithm for each k.

    The exact product ``A_orig @ B_orig.T`` is computed once. For each k,
    deterministic sampling and greedy OMP are run a single time, while uniform
    sampling, RMM (leverage-score) sampling and Gaussian projection are run
    ``n_trials_per_k`` times; their absolute errors are averaged with
    ``np.nanmean``. Errors are divided by the Frobenius norm of the exact
    product.

    Args:
        A_orig: Dense array or scipy sparse matrix.
        B_orig: Dense array or scipy sparse matrix with the same number of
            columns as ``A_orig``.
        k_values_list: The k values to evaluate.
        n_trials_per_k: Number of repetitions for the randomised algorithms.
        main_seed_exp_runner: Base value from which per-trial seeds are derived.

    Returns:
        ``(results_algo, info)``. ``results_algo`` maps each algorithm name to
        an array with one relative error per k (NaN where a run failed) and
        ``'k'`` to the k values; ``info`` is ``{'frob_ABt_sq': value}``, with
        NaN when the exact product or its norm could not be computed.
    """
    results_algo = {
        'Uniform': np.full(len(k_values_list), np.nan), 'Error Leverage Score (Actual)': np.full(len(k_values_list), np.nan),
        'Deterministic': np.full(len(k_values_list), np.nan), 'Error Gaussian (Actual)': np.full(len(k_values_list), np.nan),
        'Error Greedy OMP (Actual)': np.full(len(k_values_list), np.nan), 'k': np.array(k_values_list)
    }
    m_rows_A, n_cols_common = A_orig.shape; p_rows_B, _ = B_orig.shape
    if n_cols_common == 0 : print("Warning: Empty common dimension."); return results_algo, {'frob_ABt_sq': np.nan}
    # Exact product and its squared Frobenius norm; any failure aborts the run
    # and returns the all-NaN result arrays.
    try:
        B_T_mat = B_orig.T.tocsr() if scipy.sparse.isspmatrix(B_orig) else B_orig.T
        G_exact = A_orig @ B_T_mat
        frob_G_sq = safe_norm(G_exact, 'fro')**2
        if pd.isna(frob_G_sq): print("FATAL: Norm of A@B.T failed."); return results_algo, {'frob_ABt_sq': np.nan}
        if frob_G_sq < 1e-20: print("Warning: A@B.T Frobenius norm is near zero.")
    except Exception as e: print(f"Error A@B.T: {e}"); return results_algo, {'frob_ABt_sq': np.nan}

    # Greedy OMP and Gaussian projection receive dense copies of the inputs.
    A_dense_for_algs = A_orig.toarray() if scipy.sparse.issparse(A_orig) else A_orig.copy()
    B_dense_for_algs = B_orig.toarray() if scipy.sparse.issparse(B_orig) else B_orig.copy() # P x n_common

    for i_k, k_val in enumerate(k_values_list):
        # --- Algorithms without randomness: one run per k. ---
        # When the norm of the exact product is near zero, the relative error
        # is reported as 0.0 for a near-zero absolute error and NaN otherwise.
        try:
            C_d, W_d, _ = deterministic_sampling(A_orig, B_orig, k_val) # B_orig is P x n_common
            approx_G_d = C_d @ W_d.T if C_d.shape[1] > 0 else (scipy.sparse.csr_matrix((m_rows_A,p_rows_B)) if scipy.sparse.issparse(A_orig) else np.zeros((m_rows_A,p_rows_B)))
            err_d_abs = safe_norm(G_exact - approx_G_d, 'fro')
            results_algo['Deterministic'][i_k] = err_d_abs / np.sqrt(frob_G_sq) if frob_G_sq > 1e-20 else (0.0 if err_d_abs < 1e-10 else np.nan)
        except Exception: results_algo['Deterministic'][i_k] = np.nan
        try: # A_dense_for_algs is N x n_common, B_dense_for_algs is P x n_common
            approx_G_omp = run_greedy_selection_omp(A_dense_for_algs, B_dense_for_algs, k_val, ABt_exact=G_exact)
            err_omp_abs = safe_norm(G_exact - approx_G_omp, 'fro')
            results_algo['Error Greedy OMP (Actual)'][i_k] = err_omp_abs / np.sqrt(frob_G_sq) if frob_G_sq > 1e-20 else (0.0 if err_omp_abs < 1e-10 else np.nan)
        except Exception: results_algo['Error Greedy OMP (Actual)'][i_k] = np.nan

        # --- Randomised algorithms: n_trials_per_k runs per k. ---
        trial_errors = {'Uniform': [], 'Error Leverage Score (Actual)': [], 'Error Gaussian (Actual)': []}
        for trial in range(n_trials_per_k):
            # Seed derived from the base seed, the k position and the trial
            # index; the three algorithms use it with offsets 0, +1 and +2.
            # NOTE(review): the offset seeds can exceed the modulus used here;
            # if the seeding call rejects such a value, that trial is recorded
            # as NaN by the except clause below — confirm this is acceptable.
            current_trial_seed = (main_seed_exp_runner * (i_k + 1) * (trial + 1)) % (2**32 - 1)
            try:
                C_u, W_u, _ = uniform_sampling(A_orig, B_orig, k_val, seed=current_trial_seed)
                approx_G_u = C_u @ W_u.T if C_u.shape[1] > 0 else (scipy.sparse.csr_matrix((m_rows_A,p_rows_B)) if scipy.sparse.issparse(A_orig) else np.zeros((m_rows_A,p_rows_B)))
                trial_errors['Uniform'].append(safe_norm(G_exact - approx_G_u, 'fro'))
            except Exception: trial_errors['Uniform'].append(np.nan)
            try:
                C_rmm, W_rmm, _ = rmm_sampling(A_orig, B_orig, k_val, seed=current_trial_seed + 1)
                approx_G_rmm = C_rmm @ W_rmm.T if C_rmm.shape[1] > 0 else (scipy.sparse.csr_matrix((m_rows_A,p_rows_B)) if scipy.sparse.issparse(A_orig) else np.zeros((m_rows_A,p_rows_B)))
                trial_errors['Error Leverage Score (Actual)'].append(safe_norm(G_exact - approx_G_rmm, 'fro'))
            except Exception: trial_errors['Error Leverage Score (Actual)'].append(np.nan)
            try:
                # run_gaussian_projection takes no seed argument, so the global
                # generator is seeded right before the call.
                np.random.seed(current_trial_seed + 2) # A_dense_for_algs (N x n_common), B_dense_for_algs (P x n_common)
                approx_G_gauss = run_gaussian_projection(A_dense_for_algs, B_dense_for_algs, k_val)
                trial_errors['Error Gaussian (Actual)'].append(safe_norm(G_exact - approx_G_gauss, 'fro'))
            except Exception: trial_errors['Error Gaussian (Actual)'].append(np.nan)
        # Average the absolute errors over the trials (ignoring NaN entries)
        # and normalise; RuntimeWarnings are silenced around the averaging.
        with warnings.catch_warnings():
            warnings.simplefilter("ignore", category=RuntimeWarning)
            for algo_name, errors_list in trial_errors.items():
                if errors_list:
                    mean_err_abs = np.nanmean(errors_list)
                    results_algo[algo_name][i_k] = mean_err_abs / np.sqrt(frob_G_sq) if frob_G_sq > 1e-20 and pd.notna(mean_err_abs) else (0.0 if pd.notna(mean_err_abs) and mean_err_abs < 1e-10 else np.nan)
                else: results_algo[algo_name][i_k] = np.nan
    return results_algo, {'frob_ABt_sq': frob_G_sq}

# ==============================================================================
# Multi-Panel Plotting Function (NO LEGENDS)
# ==============================================================================
def plot_multi_panel_results(
    results_list: List[Dict[str, np.ndarray]], matrix_type_names: List[str],
    figure_title_prefix: str, common_dim_for_k_ratio: int,
    styles_dict: Dict[str, Dict] = IMPROVED_STYLES, plot_dir_path: str = "plots",
    log_scale_y_axis: bool = True
) -> None:
    """Draw one subplot per result dict (no legends), save the figure, and show it.

    Args:
        results_list: One dict per subplot. The x-values are read from the
            ``'k'`` entry (or ``'k_values_bounds'`` when ``'k'`` is absent).
            Every other entry whose key is not a known metadata key is
            treated as a y-series and must be a list/ndarray with the same
            length as the k-values. An optional ``'Rho_G'`` entry is shown
            in the subplot title.
        matrix_type_names: Subplot titles, index-aligned with ``results_list``.
        figure_title_prefix: Used to build the output file name (after being
            passed through ``sanitize_filename``).
        common_dim_for_k_ratio: Divisor applied to the k-values for the
            x-axis; when it is not positive the raw k-values are plotted.
        styles_dict: Maps series key -> keyword arguments for ``Axes.plot``.
            A ``'label'`` entry is dropped; keys without a style use
            ``FALLBACK_STYLE``.
        plot_dir_path: Directory the PNG is written to; created if missing.
        log_scale_y_axis: If True, use a log y-axis and drop points that are
            not greater than 1e-9.

    Returns:
        None. Failures while saving are printed, not raised.
    """
    num_subplots_total = len(results_list)
    if num_subplots_total == 0:
        print(f"No results to plot for {figure_title_prefix}.")
        return

    # Grid layout: a single row up to 3 panels, otherwise rows of 3.
    if num_subplots_total <= 3:
        num_rows, num_cols = 1, num_subplots_total
    elif num_subplots_total <= 6:
        num_rows, num_cols = 2, 3
    else:
        num_rows, num_cols = (num_subplots_total + 2) // 3, 3

    fig, axes = plt.subplots(num_rows, num_cols, figsize=(8 * num_cols, 6.5 * num_rows), squeeze=False)
    axes = axes.flatten()
    fig.patch.set_facecolor('white')

    # Entries of a result dict that are metadata rather than plottable series
    # (compared against the lower-cased key).
    non_series_keys = frozenset((
        'k', 'rho_g', 'g_exact_abt', 'a_orig', 'b_orig', 'frob_abt_sq',
        'sampled_indices_deterministic', 'k_values_bounds', 'k_values',
    ))

    for i, results_data_dict in enumerate(results_list):
        ax_current = axes[i]
        ax_current.set_facecolor('white')
        matrix_type_name_str = matrix_type_names[i]

        k_absolute_values = results_data_dict.get('k', results_data_dict.get('k_values_bounds'))
        if k_absolute_values is None or len(k_absolute_values) == 0:
            ax_current.set_title(f"{matrix_type_name_str}\n(No k-data)")
            continue
        # Lists are accepted too: arithmetic and mask indexing below need an ndarray.
        k_absolute_values = np.asarray(k_absolute_values)
        if common_dim_for_k_ratio > 0:
            k_axis_values = k_absolute_values / common_dim_for_k_ratio
        else:
            k_axis_values = k_absolute_values

        rho_g_metric_val = results_data_dict.get('Rho_G', np.nan)
        subplot_title = f"{matrix_type_name_str.replace('_', ' ')}"
        if pd.notna(rho_g_metric_val):
            subplot_title += f"\n($\\rho_G \\approx {rho_g_metric_val:.2f}$)"
        ax_current.set_title(subplot_title)
        ax_current.set_xlabel("Sparsity Ratio (k / common_dim)")
        if i % num_cols == 0:
            # Y-label only on the first column of the grid.
            ax_current.set_ylabel("Relative Error / Bound Ratio")

        lines_plotted_this_subplot = 0
        all_plot_vals_for_ylim = []
        for data_key, raw_series in results_data_dict.items():
            if data_key.lower() in non_series_keys:
                continue

            # Normalise to a float ndarray so list inputs (allowed by the type
            # check) support comparisons and boolean-mask indexing.
            y_values_array = None
            if isinstance(raw_series, (np.ndarray, list)):
                try:
                    y_values_array = np.asarray(raw_series, dtype=float)
                except (TypeError, ValueError):
                    y_values_array = None
            if (
                y_values_array is None
                or y_values_array.ndim != 1
                or len(y_values_array) != len(k_axis_values)
            ):
                warnings.warn(f"Data '{data_key}' in '{matrix_type_name_str}' invalid. Skip.", UserWarning)
                continue

            style_to_apply = styles_dict.get(data_key, FALLBACK_STYLE)
            if data_key not in styles_dict:
                print(f"  Plot Warn: Style for '{data_key}' not in IMPROVED_STYLES. Fallback.")

            # isfinite() already rejects NaN as well as +/-inf.
            valid_data_indices = np.isfinite(y_values_array)
            if log_scale_y_axis:
                # Non-positive / vanishing values cannot be drawn on a log axis.
                valid_data_indices &= (y_values_array > 1e-9)
            if not np.any(valid_data_indices):
                continue

            all_plot_vals_for_ylim.extend(y_values_array[valid_data_indices])
            # Legends are intentionally omitted, so the 'label' kwarg is dropped.
            plot_kwargs = {k_sty: v_sty for k_sty, v_sty in style_to_apply.items() if k_sty != 'label'}
            ax_current.plot(k_axis_values[valid_data_indices], y_values_array[valid_data_indices], **plot_kwargs)
            lines_plotted_this_subplot += 1

        if lines_plotted_this_subplot == 0:
            ax_current.text(0.5, 0.5, "No valid data", ha='center', va='center', transform=ax_current.transAxes)

        if all_plot_vals_for_ylim:
            min_val, max_val = np.min(all_plot_vals_for_ylim), np.max(all_plot_vals_for_ylim)
            if log_scale_y_axis:
                # One decade of headroom on each side, at least two decades in total.
                bottom_lim = max(min_val * 0.1, 1e-8) if min_val > 1e-9 else 1e-8
                top_lim = max(max_val * 10, bottom_lim * 100)
                ax_current.set_ylim(bottom=bottom_lim, top=top_lim)
                ax_current.set_yscale('log')
            else:
                value_span = max_val - min_val
                padding = value_span * 0.1 if value_span > 1e-6 else 0.1
                ax_current.set_ylim(bottom=min_val - padding, top=max_val + padding)
        elif log_scale_y_axis:
            ax_current.set_ylim(bottom=1e-7, top=1.0)
            ax_current.set_yscale('log')
        else:
            ax_current.set_ylim(bottom=0, top=1.0)

        ax_current.spines['top'].set_visible(False)
        ax_current.spines['right'].set_visible(False)
        ax_current.grid(True, which="both", linestyle=":", linewidth=plt.rcParams['grid.linewidth'], alpha=plt.rcParams['grid.alpha'])

    # Remove the unused cells of the grid.
    for j in range(num_subplots_total, len(axes)):
        fig.delaxes(axes[j])
    # Leave a band at the top in case a suptitle is added.
    fig.tight_layout(rect=[0, 0, 1, 0.96])
    # fig.suptitle(figure_title_prefix.replace("_", " "), fontsize=plt.rcParams['figure.titlesize']) # Optional suptitle

    sanitized_fig_title = sanitize_filename(figure_title_prefix)
    plot_file_path = os.path.join(plot_dir_path, f"{sanitized_fig_title}_multi_panel_NO_LEGEND.png")
    try:
        if plot_dir_path:
            # Without this, saving into a not-yet-existing directory always fails.
            os.makedirs(plot_dir_path, exist_ok=True)
        fig.savefig(plot_file_path, facecolor='white', bbox_inches='tight', dpi=plt.rcParams['savefig.dpi'])
        print(f"Plot saved: {plot_file_path}")
    except Exception as e:
        # Best effort: a failed save must not prevent the figure from being shown.
        print(f"Error saving plot {plot_file_path}: {e}")
    plt.show()
    plt.close(fig)

